"""
Faraday Penetration Test IDE
Copyright (C) 2016  Infobyte LLC (https://faradaysec.com/)
See the file 'doc/LICENSE' for the license information
"""

# Standard library imports
import json
import re
import threading
import logging
import csv
from io import TextIOWrapper

# Related third party imports
import wtforms
from filteralchemy import (
    FilterSet,
    operators,
)
from flask import (
    Blueprint,
    request,
    abort,
    jsonify,
    make_response,
)
from flask_classful import route
from flask_wtf.csrf import validate_csrf
from marshmallow import fields, ValidationError, post_load, pre_load, post_dump
from marshmallow.validate import OneOf
from werkzeug.exceptions import Conflict

# Local application imports
from faraday.server.api.base import (
    AutoSchema,
    FilterAlchemyMixin,
    FilterSetMeta,
    PaginatedMixin,
    ReadWriteView,
    FilterMixin,
    BulkDeleteMixin,
    BulkUpdateMixin,
)
from faraday.server.api.modules.vulns_base import ImpactSchema
from faraday.server.schemas import (
    PrimaryKeyRelatedField,
    SeverityField,
    SelfNestedField,
    FaradayCustomField,
)
from faraday.server.models import (
    db,
    CustomFieldsSchema,
    Vulnerability,
    VulnerabilityTemplate,
)

# Logger shared by every class and helper in this module.
logger = logging.getLogger(__name__)
# Blueprint that VulnerabilityTemplateView registers its routes on (last line of this module).
vulnerability_template_api = Blueprint('vulnerability_template_api', __name__)


class CVSS2Schema(AutoSchema):
    """Nested ``cvss2`` object: the model's ``cvss2_vector_string`` exposed as ``vector_string``."""
    vector_string = fields.String(required=False, allow_none=True, attribute="cvss2_vector_string")


class CVSS3Schema(AutoSchema):
    """Nested ``cvss3`` object: the model's ``cvss3_vector_string`` exposed as ``vector_string``."""
    vector_string = fields.String(required=False, allow_none=True, attribute="cvss3_vector_string")


class CVSS4Schema(AutoSchema):
    """Nested ``cvss4`` object: the model's ``cvss4_vector_string`` exposed as ``vector_string``."""
    vector_string = fields.String(required=False, allow_none=True, attribute="cvss4_vector_string")


class VulnerabilityTemplateSchema(AutoSchema):
    """(De)serialization schema for ``VulnerabilityTemplate``.

    Several model attributes travel under legacy names: ``exploitation`` is
    ``severity``; ``desc``/``refs`` are dump-only aliases of ``description``
    and ``references``; ``easeofresolution``, ``policyviolations`` and
    ``customfields`` map to their snake_case attributes. The ``impact`` and
    ``cvss2``/``cvss3``/``cvss4`` objects are nested on the wire and flattened
    onto the model attributes by the post-load hooks. ``cve`` is exchanged as
    a list but stored as a comma-separated string.
    """
    _id = fields.Integer(dump_only=True, attribute='id')
    id = fields.Integer(dump_only=True, attribute='id')
    _rev = fields.String(default='', dump_only=True)
    cwe = fields.String(dump_only=True, default='')  # deprecated field, the legacy data is added to refs on import
    exploitation = SeverityField(attribute='severity', required=True)
    references = fields.Method('get_references', deserialize='load_references')
    refs = fields.List(fields.String(), dump_only=True, attribute='references')
    desc = fields.String(dump_only=True, attribute='description')
    data = fields.String(attribute='data')
    impact = SelfNestedField(ImpactSchema())
    easeofresolution = fields.String(
        attribute='ease_of_resolution',
        validate=OneOf(Vulnerability.EASE_OF_RESOLUTIONS),
        allow_none=True)
    policyviolations = fields.List(fields.String,
                                   attribute='policy_violations')
    creator = PrimaryKeyRelatedField('username', dump_only=True, attribute='creator')
    creator_id = fields.Integer(dump_only=True, attribute='creator_id')

    create_at = fields.DateTime(attribute='create_date',
                                dump_only=True)
    cve = fields.String(default="", required=False, allow_none=True)
    cvss2 = SelfNestedField(CVSS2Schema(), required=False, allow_none=True)
    cvss3 = SelfNestedField(CVSS3Schema(), required=False, allow_none=True)
    cvss4 = SelfNestedField(CVSS4Schema(), required=False, allow_none=True)

    external_id = fields.String(allow_none=True)
    # Here we use vulnerability instead of vulnerability_template to avoid duplicate row
    # in the custom_fields_schema table.
    # All validation will be against vulnerability table.
    customfields = FaradayCustomField(table_name='vulnerability', attribute='custom_fields')

    class Meta:
        model = VulnerabilityTemplate
        fields = ('id', '_id', '_rev', 'cwe', 'description', 'desc',
                  'exploitation', 'name', 'references', 'refs', 'resolution',
                  'impact', 'easeofresolution', 'policyviolations', 'data',
                  'external_id', 'creator', 'create_at', 'creator_id',
                  'customfields', 'cve', 'cvss2', 'cvss3', 'cvss4')

    @staticmethod
    def get_references(obj):
        """Dump the template's reference names as one comma-separated string."""
        return ', '.join(map(lambda ref_tmpl: ref_tmpl.name, obj.reference_template_instances))

    @staticmethod
    def load_references(value):
        """Load references from a list of names or a comma-separated string.

        Bytes are decoded as UTF-8 and string items are stripped. An empty
        string yields ``[]``.

        Raises:
            ValidationError: if ``value`` is neither a string nor a list, or if
                any reference name is empty.
        """
        if isinstance(value, bytes):
            value = value.decode('utf-8')
        if isinstance(value, list):
            references = value
        elif isinstance(value, str):
            if len(value) == 0:
                # Required because "".split(",") == [""]
                return []
            references = [ref.strip() for ref in value.split(',')]
        else:
            raise ValidationError('references must be a either a string '
                                  'or a list')
        if any(len(ref) == 0 for ref in references):
            raise ValidationError('Empty name detected in reference')
        return references

    @post_load
    def post_load_impact(self, data, **kwargs):
        """Unflatten ``impact``: move each ``data['impact'][k]`` to ``data[k]``."""
        impact = data.pop('impact', None)
        if impact:
            data.update(impact)
        return data

    @post_load
    def post_load_cvss_fields(self, data, **kwargs):
        """Move ``data['cvssN']['cvssN_vector_string']`` to ``data['cvssN_vector_string']``.

        A ``null`` cvss object (the nested fields allow None) is dropped
        without setting any vector string.
        """
        for version in ('cvss2', 'cvss3', 'cvss4'):
            if version not in data:
                continue
            vector_key = f'{version}_vector_string'
            cvss = data.pop(version)
            # Guard against None: `key in None` would raise TypeError.
            if cvss and vector_key in cvss:
                data[vector_key] = cvss[vector_key]
        return data

    @pre_load
    def pre_load_cve(self, data, **kwargs):
        """Normalize the incoming ``cve`` into the comma-separated string the model stores.

        Accepts a list of ids or a comma-separated string (e.g. a CSV cell).
        Entries are stripped, and those that are not strings or do not start
        with ``CVE-YYYY-NNNN`` are dropped. When nothing valid remains, None is
        stored. A ``None`` value, or any other type, is passed through for the
        ``cve`` field itself to accept or reject. The caller's dict is never
        modified.
        """
        if "cve" not in data:
            return data

        raw_cve = data["cve"]
        if isinstance(raw_cve, str):
            # Splitting keeps a string from being iterated character by character.
            candidates = raw_cve.split(",")
        elif isinstance(raw_cve, list):
            candidates = raw_cve
        else:
            # None is allowed by the field; other types get a regular
            # ValidationError from the String field instead of crashing here.
            return data

        regex = r'CVE-\d{4}-\d{4,7}'
        validated_list = [
            cve.strip() for cve in candidates
            if isinstance(cve, str) and re.match(regex, cve.strip())
        ]

        # Work on a copy so the dict handed to load() is left untouched.
        data = dict(data)
        # Store None rather than [] (a list is not a valid String value).
        # post_dump_cve renders None back as [].
        data["cve"] = ",".join(validated_list) if validated_list else None
        return data

    @post_dump
    def post_dump_cve(self, data, **kwargs):
        """Dump the stored comma-separated ``cve`` string as a list (``[]`` when empty/None)."""
        if "cve" in data:
            if data["cve"] == "" or data["cve"] is None:
                data["cve"] = []
            else:
                data["cve"] = data["cve"].split(",")
        return data


class VulnerabilityTemplateFilterSet(FilterSet):
    """Filter set for vulnerability templates: equality filtering on ``severity``."""

    class Meta(FilterSetMeta):
        model = VulnerabilityTemplate  # It has all the fields
        # The trailing comma makes this a one-element tuple; ('severity') is just
        # the string 'severity', so membership checks against it would be
        # substring checks.
        fields = ('severity',)
        operators = (operators.Equal,)


# Process-local mutex. VulnerabilityTemplateView.post holds it around the
# inherited create, so single-template POSTs handled by threads of this process
# run one at a time.
lock = threading.Lock()


class VulnerabilityTemplateView(PaginatedMixin,
                                FilterAlchemyMixin,
                                ReadWriteView,
                                FilterMixin,
                                BulkDeleteMixin,
                                BulkUpdateMixin):
    """API view for vulnerability templates, mounted at ``vulnerability_template``.

    CRUD, filtering, pagination and bulk update/delete come from the mixins.
    This class wraps lists in a ``rows``/``count`` envelope, serializes
    single-template creation with the module ``lock``, and adds the
    ``/bulk_create`` endpoint (a CSV upload or a JSON list of vulnerabilities).
    """
    route_base = 'vulnerability_template'
    model_class = VulnerabilityTemplate
    schema_class = VulnerabilityTemplateSchema
    filterset_class = VulnerabilityTemplateFilterSet
    get_joinedloads = [VulnerabilityTemplate.creator]

    _CSV_REQUIRED_HEADERS = frozenset({'name', 'exploitation'})
    _IMPACT_KEYS = ('accountability', 'availability', 'confidentiality', 'integrity')

    def _envelope_list(self, objects, pagination_metadata=None):
        """Wrap listed objects as ``{'rows': objects, 'count': N}``.

        ``N`` is ``pagination_metadata.total`` when pagination metadata is
        given, otherwise ``len(objects)``.
        """
        return {
            'rows': objects,
            'count': (pagination_metadata.total
                      if pagination_metadata is not None else len(objects))
        }

    def post(self, **kwargs):
        """
        ---
        post:
          tags: ["VulnerabilityTemplate"]
          summary: Creates VulnerabilityTemplate
          requestBody:
            required: true
            content:
              application/json:
                schema: VulnerabilityTemplateSchema
          responses:
            201:
              description: Created
              content:
                application/json:
                  schema: VulnerabilityTemplateSchema
            409:
              description: Duplicated key found
              content:
                application/json:
                  schema: VulnerabilityTemplateSchema
        """
        # Delegate to the inherited create, one request at a time in this process.
        with lock:
            return super().post(**kwargs)

    @route('/bulk_create', methods=['POST'])
    def bulk_create(self):
        """
        ---
        post:
          tags: ["Bulk", "VulnerabilityTemplate"]
          description: Creates Vulnerability templates in bulk
          responses:
            200:
              description: At least one template was created
            400:
              description: Bad request
            403:
              description: Forbidden
            409:
              description: No template was created and at least one conflicted with an existing one
        """
        # Input is either a multipart upload with a CSV under ``file`` or a JSON
        # body whose ``vulns`` key holds a list of vulnerability dicts. Each item
        # is created independently. The response lists (_id, name) pairs for
        # created, invalid and conflicting items.
        json_body = self._get_json_body()

        # The CSRF token may come from the multipart form or from the JSON body.
        csrf_token = request.form.get('csrf_token', '') or json_body.get('csrf_token', '')
        try:
            validate_csrf(csrf_token)
        except wtforms.ValidationError:
            logger.error("Invalid CSRF token.")
            abort(make_response({"message": "Invalid CSRF token."}, 403))

        if 'file' in request.files:
            logger.info("Create vulns template from CSV")
            vulns_to_create = self._read_vulns_from_csv(request.files['file'])
        elif json_body.get('vulns'):
            logger.info("Create vulns template from vulnerabilities in Status Report")
            vulns_to_create = self._read_vulns_from_json(json_body['vulns'])
        else:
            vulns_to_create = None

        if not vulns_to_create:
            logger.error("Missing data to create vulnerabilities templates.")
            abort(make_response({"message": "Missing data to create vulnerabilities templates."}, 400))

        vulns_created = []
        vulns_with_errors = []
        vulns_with_conflict = []
        schema = self.schema_class()
        for vuln in vulns_to_create:
            # .get(): an invalid item may lack 'name' and must still be reported.
            vuln_ref = (vuln.get('_id', ''), vuln.get('name'))
            try:
                vuln_schema = schema.load(vuln)
                super()._perform_create(vuln_schema)
                db.session.commit()
            except ValidationError:
                vulns_with_errors.append(vuln_ref)
            except Conflict:
                vulns_with_conflict.append(vuln_ref)
            else:
                vulns_created.append(vuln_ref)

        # Every item lands in exactly one of the three lists, and the input
        # list is non-empty here.
        if vulns_created:
            status_code = 200
        elif vulns_with_conflict:
            status_code = 409
        else:
            status_code = 400
        logger.info("Vulnerability templates created in bulk")
        return make_response(
            jsonify(vulns_created=vulns_created,
                    vulns_with_errors=vulns_with_errors,
                    vulns_with_conflict=vulns_with_conflict),
            status_code
        )

    @staticmethod
    def _get_json_body():
        """Return the request's JSON object body, or ``{}`` if it is absent, malformed or not an object."""
        # silent=True avoids the 415/None failure modes of ``request.json`` on
        # multipart or non-JSON requests.
        body = request.get_json(silent=True)
        return body if isinstance(body, dict) else {}

    @classmethod
    def _read_vulns_from_csv(cls, vulns_file):
        """Parse an uploaded CSV into template dicts; abort with 400 if required headers are missing."""
        # NOTE(review): Content-Encoding names a transfer coding (e.g. gzip), not a
        # text charset. The charset from the upload's Content-Type is probably what
        # was meant here — confirm before changing.
        io_wrapper = TextIOWrapper(vulns_file, encoding=request.content_encoding or "utf8")
        vulns_reader = csv.DictReader(io_wrapper, skipinitialspace=True)

        # fieldnames is None when the upload has no header row (empty file).
        diff_required = set(cls._CSV_REQUIRED_HEADERS).difference(vulns_reader.fieldnames or ())
        if diff_required:
            logger.error(f"Missing required headers in CSV: {diff_required}")
            abort(
                make_response(
                    {"message": f"Missing required headers in CSV: {diff_required}"}, 400
                )
            )
        return cls._parse_vuln_from_file(vulns_reader)

    @staticmethod
    def _read_vulns_from_json(vulns):
        """Validate the JSON ``vulns`` list and expose ``custom_fields`` as the schema's ``customfields``."""
        if not isinstance(vulns, list) or not all(isinstance(vuln, dict) for vuln in vulns):
            logger.error("Invalid data to create vulnerabilities templates.")
            abort(make_response({"message": "Invalid data to create vulnerabilities templates."}, 400))
        for vuln in vulns:
            # Due to the definition in the model, we need to
            # rename 'custom_fields' attribute to 'customfields'.
            # A client-supplied 'customfields' is kept when 'custom_fields' is absent.
            if 'custom_fields' in vuln:
                vuln['customfields'] = vuln['custom_fields']
            else:
                vuln.setdefault('customfields', {})
        return vulns

    @staticmethod
    def _parse_custom_field_value(cf_schema, key, value):
        """Convert one CSV cell for custom field ``key``.

        Returns ``(True, parsed_value)`` when the cell should be stored, or
        ``(False, None)`` when it is skipped. A warning is logged when a cell
        is skipped.
        """
        if cf_schema.field_type == 'list' and value:
            # Typographic single quotes are replaced by '"' before JSON decoding.
            custom_field_value = value.replace('‘', '"').replace('’', '"')
            try:
                return True, json.loads(custom_field_value)
            except ValueError:
                logger.warning(f'Invalid list for custom field {key}. '
                               f'Faraday will skip this custom field.')
                return False, None
        if cf_schema.field_type == 'choice' and value:
            cf_choices = cf_schema.field_metadata
            if isinstance(cf_choices, str):
                cf_choices = json.loads(cf_choices)
            if value not in cf_choices:
                logger.warning(f'Invalid choice for custom field {key}. '
                               f'Faraday will skip this custom field.')
                return False, None
        return True, value

    @classmethod
    def _parse_vuln_from_file(cls, vulns_reader):
        """Turn CSV rows into dicts ready for the schema's ``load``.

        Columns whose name matches a registered custom field are copied into
        ``customfields`` (list cells are JSON-decoded and choice cells
        validated; invalid cells are skipped). The four impact columns are
        collected into ``impact``, defaulting to False when a column is absent.
        All other columns are left as they are.
        """
        custom_fields = {cf_schema.field_name: cf_schema
                         for cf_schema in db.session.query(CustomFieldsSchema).all()}
        vulns_list = []
        for vuln_dict in vulns_reader:
            parsed_custom_fields = {}
            for key, value in vuln_dict.items():
                cf_schema = custom_fields.get(key)
                if cf_schema is None:
                    continue
                keep, parsed_value = cls._parse_custom_field_value(cf_schema, key, value)
                if keep:
                    parsed_custom_fields[key] = parsed_value
            vuln_dict['customfields'] = parsed_custom_fields
            vuln_dict['impact'] = {
                impact_key: vuln_dict.get(impact_key, False)
                for impact_key in cls._IMPACT_KEYS
            }
            vulns_list.append(vuln_dict)

        return vulns_list


# Register the view's routes (flask_classful ``register``) on the
# ``vulnerability_template_api`` blueprint created at the top of this module.
VulnerabilityTemplateView.register(vulnerability_template_api)
